In 13-graph-geodesics.ipynb, we saw that the graph approach gives promising results on the cone dataset. In this notebook, we use the same approach and directly compare it with the discrete geodesic algorithm from 08-discrete-geodesics.ipynb. We will see that the graph approach works especially well on datasets where the Euclidean interpolation leads to bad local minima. Also, we use the stochastic Riemannian metric this time and see that the graph approach still works well.
# Auto-reload the local src modules whenever they change on disk.
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Set up plotly
init_notebook_mode(connected=True)
# Shared figure layout reused by every plot in this notebook.
# NOTE(review): go.Margin is a legacy alias (newer plotly versions use
# go.layout.Margin) — fine for the plotly version this notebook targets.
layout = go.Layout(
width=700,
height=500,
margin=go.Margin(l=60, r=60, b=40, t=20),
showlegend=False
)
# Hide the "Export to plot.ly" link on all figures.
config={'showLink': False}
# Two-color scale used to distinguish the two moon classes.
colorscale=[[0.0, '#3595E3'], [1.0, '#2D4366']]
# Make results completely deterministic
# (tf.set_random_seed is the TF 1.x graph-level seed API).
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
We use 10,000 points for training and plot 1,000 points of these for the figures.
from sklearn.datasets import make_moons
from sklearn.preprocessing import MinMaxScaler

# Generate the two-moons dataset and rescale every feature to [-1, 1].
x, y = make_moons(n_samples=10000, noise=0.075, random_state=seed)
x = MinMaxScaler((-1, 1)).fit_transform(x)
# Keep a fixed 1,000-point subset for all scatter plots below.
x_plot, y_plot = x[:1000, :], y[:1000]

# Scatter plot of the plotting subset, colored by class.
scatter_plot = go.Scatter(
    x=x_plot[:, 0],
    y=x_plot[:, 1],
    mode='markers',
    marker=dict(color=y_plot, colorscale=colorscale)
)
iplot(go.Figure(data=[scatter_plot], layout=layout), config=config)
For our encoder and decoder, we use simple neural networks with 2 hidden layers of 10 neurons each.
# Keras model-building utilities (TF 1.x keras path).
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda, Flatten
from src.vae import VAE
from src.rbf import RBFLayer
# Data space and latent space are both 2-D (easy to visualize).
input_dim = 2
latent_dim = 2
# Create the encoder models
# The mean and variance encoders share two softplus hidden layers; the
# variance head uses a softplus output to keep variances positive.
# NOTE: keep the layer construction order stable — with the fixed seeds
# above, the order in which layers are created determines the weight
# initialization and hence the exact results.
enc_input = Input((input_dim,))
enc_shared = Sequential([
Dense(10, activation='softplus'),
Dense(10, activation='softplus')
])
enc_mean = Model(enc_input, Dense(latent_dim, activation='linear')(
enc_shared(enc_input)))
enc_var = Model(enc_input, Dense(latent_dim, activation='softplus')(
enc_shared(enc_input)))
# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
Dense(10, activation='softplus'),
Dense(10, activation='softplus'),
Dense(input_dim, activation='linear')
])
dec_mean = Model(dec_input, dec_mean(dec_input))
# Build the RBF network
# The RBF layer models the decoder variance; `a` scales the RBF
# bandwidths computed further below (see the bandwidth cell).
num_centers = 20
a = 0.75
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))
# NOTE(review): dec_stddev=0.1 is presumably the fixed observation noise
# used before the variance network is trained — confirm in src.vae.
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=0.1)
# Fit the VAE; the reported training loss is the NELBO.
history = vae.model.fit(x, shuffle=True, epochs=20, batch_size=100, verbose=0)

# Plot the training loss per epoch.
loss_layout = deepcopy(layout)
loss_layout['xaxis'] = {'title': 'Epoch'}
loss_layout['yaxis'] = {'title': 'NELBO'}
iplot(go.Figure(data=[go.Scatter(y=history.history['loss'])],
                layout=loss_layout),
      config=config)
# Display a 2D plot of the classes in the latent space
# (the second encoder output holds the latent means; hover text shows the
# point index so we can pick task endpoints later).
_, encoded_mean, _ = vae.encoder.predict(x_plot)
latent_scatter_plot = go.Scatter(
    x=encoded_mean[:, 0],
    y=encoded_mean[:, 1],
    mode='markers',
    marker=dict(color=y_plot, colorscale=colorscale),
    hoverinfo='text',
    text=np.arange(len(encoded_mean))
)
iplot(go.Figure(data=[latent_scatter_plot], layout=layout), config=config)
from sklearn.cluster import KMeans

# Cluster the latent codes of the full training set; the cluster centers
# will serve as the centers of the RBF variance network.
_, encoded_train, _ = vae.encoder.predict(x)
kmeans = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_train)
centers = kmeans.cluster_centers_

# Overlay the centers (red) on the latent scatter plot.
center_plot = go.Scatter(
    x=centers[:, 0],
    y=centers[:, 1],
    mode='markers',
    marker=dict(color='red')
)
iplot(go.Figure(data=[latent_scatter_plot, center_plot], layout=layout),
      config=config)
This follows equation 11 from the Latent Space Oddity paper.
# Cluster the latent representations by their assigned k-means label.
labels = kmeans.predict(encoded_train)
clustering = {c: [] for c in range(num_centers)}
for z, c in zip(encoded_train, labels):
    clustering[c].append(z)

# One bandwidth per cluster: 1 / (2 * (a * average distance to center)^2),
# following eq. 11 of the Latent Space Oddity paper.
bandwidths = []
for c in range(num_centers):
    members = clustering[c]
    if members:
        dists = np.linalg.norm(np.array(members) - centers[c], axis=1)
        bandwidths.append(0.5 / (a * np.mean(dists))**2)
    else:
        # Empty cluster: there is no data to estimate a bandwidth from.
        bandwidths.append(0)
bandwidths = np.array(bandwidths)
Next, we train the RBF variance network using the centers and bandwidths computed above.
# Train the RBF
# Recompile so that only the variance network is trained, then plug the
# k-means centers and the computed bandwidths into the RBF layer while
# keeping its trainable kernel.
vae.recompile_for_var_training()
kernel = rbf.get_weights()[0]
rbf.set_weights([kernel, centers, bandwidths])
history = vae.model.fit(x, shuffle=True, epochs=200, batch_size=100, verbose=0)

# Plot the training loss per epoch.
var_layout = deepcopy(layout)
var_layout['xaxis'] = {'title': 'Epoch'}
var_layout['yaxis'] = {'title': 'NELBO'}
iplot(go.Figure(data=[go.Scatter(y=history.history['loss'])],
                layout=var_layout),
      config=config)
# Feed the encoder mean to the decoder
# and plot the decoded means, colored by class.
decoded, decoded_mean, decoded_var = vae.decoder.predict(encoded_mean)
reconstruction_plot = go.Scatter(
    x=decoded_mean[:, 0],
    y=decoded_mean[:, 1],
    mode='markers',
    marker=dict(color=y_plot, colorscale=colorscale)
)
iplot(go.Figure(data=[reconstruction_plot], layout=layout), config=config)
# Pick two encoded points as the endpoints of the geodesic task
# (indices refer to the hover text of the latent scatter plot).
z_start_index = 39
z_end_index = 972
z_start, z_end = encoded_mean[z_start_index], encoded_mean[z_end_index]
task_plot = go.Scatter(
    x=np.array([z_start[0], z_end[0]]),
    y=np.array([z_start[1], z_end[1]]),
    mode='markers',
    marker=dict(color='red')
)
Below, we plot the magnification factor of the decoder together with the two points of our interpolation task.
# Get the mean and std predictors
from src.util import wrap_model_in_float64

# Split the trained decoder into separate mean and standard-deviation
# models, flattened and wrapped in float64 (presumably for numerical
# precision in the curve computations — see src.util).
_, mean_out, var_out = vae.decoder.output
std_out = Lambda(tf.sqrt)(var_out)
dec_mean = wrap_model_in_float64(Model(vae.decoder.input, Flatten()(mean_out)))
dec_std = wrap_model_in_float64(Model(vae.decoder.input, Flatten()(std_out)))
from src.plot import plot_magnification_factor

# Evaluate the magnification factor on a 100x100 grid over [-3, 3]^2
# (log scale) and overlay the latent data and the two task points.
sess = tf.keras.backend.get_session()
grid_1d = np.linspace(-3, 3, 100)
heatmap = plot_magnification_factor(sess,
                                    grid_1d,
                                    grid_1d,
                                    dec_mean,
                                    dec_std,
                                    additional_data=[latent_scatter_plot,
                                                     task_plot],
                                    layout=layout,
                                    log_scale=True)
Now that we have our task, we can use the discrete geodesic algorithm from Shao et al. to approximate the geodesic / compute the Riemannian distance between the two points. The algorithm is implemented and explained in 08-discrete-geodesics.ipynb.
%%time
from src.discrete import find_geodesic_discrete
# Approximate the geodesic between z_start and z_end with the discrete
# algorithm (Shao et al., explained in 08-discrete-geodesics.ipynb).
# Intermediate curves are saved every `save_every` steps for plotting.
discrete_curve, iterations = find_geodesic_discrete(sess, z_start, z_end,
dec_mean,
std_generator=dec_std,
num_nodes=1000,
learning_rate=0.001,
max_steps=500,
save_every=50,
log_every=50)
print('-' * 20)
from src.plot import plot_latent_curve_iterations
# Show how the curve evolved over the optimization steps.
plot_latent_curve_iterations(iterations,
[heatmap, latent_scatter_plot, task_plot],
layout,
step_size=500)
We see that the curve length steadily decreases for the discrete geodesic algorithm. However, in the plot above we also see that the algorithm "jumps" over regions of large magnification factors in order to avoid "paying" for the regions. You can zoom in on the green points in the plot above to see the jumps. Therefore, the length estimate of the discrete algorithm is flawed. We need a function that removes these jumps and takes equidistant steps in the latent space in order to measure the Riemannian length of a curve:
def interpolate(curve, num_nodes):
    """Resample a polyline to `num_nodes` points that are equidistant
    along the curve (w.r.t. Euclidean length in the latent space).

    The discrete geodesic algorithm can leave consecutive nodes far apart
    ("jumps"); measuring the Riemannian length on such a curve under-counts
    the regions that were jumped over. Equidistant resampling removes that
    bias.

    Args:
        curve: Array of shape (num_curve_nodes, dim) with >= 2 nodes.
        num_nodes: Number of nodes of the resampled curve (>= 2).

    Returns:
        Array of shape (num_nodes, dim) whose first and last nodes equal
        the endpoints of `curve`.

    Raises:
        ValueError: If num_nodes < 2.
    """
    curve = np.asarray(curve)
    if num_nodes < 2:
        raise ValueError('num_nodes must be at least 2')
    # Take equidistant steps in the latent space.
    latent_lengths = np.linalg.norm(curve[1:] - curve[:-1], axis=1)
    step_size = np.sum(latent_lengths) / (num_nodes - 1)
    interpolation = [curve[0]]
    last_index = len(curve) - 1
    i_curve_node = 1
    # Construct the resampled curve step by step.
    for _ in range(num_nodes - 2):
        current_step_length = 0.0
        position = interpolation[-1]
        curve_step_length = np.linalg.norm(curve[i_curve_node] - position)
        # Consume whole curve segments that fit into the current step.
        # BUG FIX: the index guard prevents floating-point round-off from
        # pushing i_curve_node past the final curve node (the original
        # code could index out of bounds on the last step).
        while (step_size - current_step_length >= curve_step_length
               and i_curve_node < last_index):
            current_step_length += curve_step_length
            position = curve[i_curve_node]
            i_curve_node += 1
            curve_step_length = np.linalg.norm(curve[i_curve_node] - position)
        # Take the missing partial step inside the current segment.
        # BUG FIX: skip zero-length segments (duplicate curve nodes) that
        # would otherwise cause a division by zero and produce NaNs.
        if curve_step_length > 0:
            relative_step = ((step_size - current_step_length) /
                             curve_step_length)
            position = position + relative_step * (curve[i_curve_node] -
                                                   position)
        interpolation.append(position)
    interpolation.append(curve[-1])
    return np.array(interpolation)
Below, we plot the corrected discrete curve with a small number of nodes. For measuring the actual Riemannian length of a curve precisely, we would take 100+ interpolation points. This plot is just to show that the interpolation function works properly.
# Sanity check: resample the discrete curve with only a few nodes so the
# equidistant spacing is visible in the plot.
coarse_curve = interpolate(discrete_curve, 20)
plot_latent_curve_iterations([coarse_curve],
                             [heatmap, latent_scatter_plot, task_plot],
                             layout)
Given our interpolate function that corrects the curve, we can now use it to measure the actual Riemannian length of the discrete algorithm's solution. The algorithm started with a Euclidean interpolation (length 55.2), reached a length of 46.6 after 50 steps and then further decreased the curve length to 26.1 after 500 steps. However, below we see that the actual curve length of the final solution after 500 steps is 48.6. Therefore, most of the steps of the discrete algorithm seem to have no effect on the actual curve length.
from src.util import get_length_op, get_lengths_op

# TF op that yields the Riemannian length of each segment of a curve fed
# through the placeholder.
curve_ph = tf.placeholder(tf.float64, [None, 2])
lengths_op = tf.squeeze(get_lengths_op(curve_ph, dec_mean, dec_std))


def evaluate_curve(curve, num_nodes=200):
    """Print the Riemannian length of `curve` after equidistant resampling."""
    resampled = interpolate(curve, num_nodes)
    segment_lengths = sess.run(lengths_op, feed_dict={curve_ph: resampled})
    print('Curve length: ', np.sum(segment_lengths))


evaluate_curve(discrete_curve)
# Build the graph node set: the encoded points plus two noisy copies to
# densify the latent space.
graph_points = encoded_mean
noisy_copies = [graph_points + np.random.randn(*graph_points.shape)
                for _ in range(2)]
graph_points = np.concatenate([graph_points] + noisy_copies)
print(graph_points.shape)

# Show the node set on top of the magnification-factor heatmap.
graph_plot = go.Scatter(
    x=graph_points[:, 0],
    y=graph_points[:, 1],
    mode='markers',
    marker=dict(color='#EE7830', size=4)
)
iplot(go.Figure(data=[graph_plot, heatmap], layout=layout), config=config)
To get the nearest neighbors of each point, we use the get_neighbors function from src.graph. It is explained and defined in 13-graph-geodesics.ipynb.
Given the get_neighbors function, compute the Riemannian distance between each point and each of its neighbors. We approximate the Riemannian distance with a single midpoint for integration: $\int_0^1 \left\| J_{\gamma_t} \dot{\gamma}_t \right\| \mathrm{d}t \approx \left\| J_{\gamma_{1/2}} \dot{\gamma}_{1/2} \right\|$
from src.graph import get_neighbors
from src.util import get_metric_op
import networkx as nx
from tqdm import tqdm

# TF op for the Riemannian metric G(z) at a single latent point.
point_ph = tf.placeholder(tf.float32, [2])
metric_op = get_metric_op(point_ph, dec_mean, dec_std)

# Add all nodes from the graph points
graph = nx.Graph()
for i_point, point in enumerate(graph_points):
    graph.add_node(i_point, pos=point)

k = 4
# Connect each point to its k nearest neighbors, weighting every edge
# with a single-midpoint approximation of the Riemannian distance:
# length = sqrt(v^T G(midpoint) v) with v = neighbor - point.
for i_point, point in enumerate(tqdm(graph_points)):
    for i_neighbor in get_neighbors(i_point, graph_points, k):
        neighbor = graph_points[i_neighbor]
        velocity = neighbor - point
        midpoint = point + 0.5 * velocity
        metric = sess.run(metric_op, feed_dict={point_ph: midpoint})
        weight = np.sqrt(velocity.T.dot(metric).dot(velocity))
        graph.add_edge(i_point, i_neighbor, weight=weight)
Below, we plot the graph and the relative weight of the edges (Riemannian length divided by Euclidean length). Green means a low relative weight, red means a large relative weight.
from src.plot import plot_graph_with_edge_colors

# Restrict the plot to the nodes inside the displayed latent region.
x_range = [-3., 3.]
y_range = [-3., 3.]
subnodes = []
for node in graph.nodes():
    # NOTE(review): graph.node[...] is the legacy NetworkX accessor
    # (graph.nodes[...] in NetworkX >= 2.0) — kept for the pinned version.
    pos = graph.node[node]['pos']
    if (x_range[0] <= pos[0] <= x_range[1] and
            y_range[0] <= pos[1] <= y_range[1]):
        subnodes.append(node)
subgraph = graph.subgraph(subnodes)
# BUG FIX: the subgraph was computed but the full graph was plotted;
# plot the restricted subgraph instead.
graph_plot = plot_graph_with_edge_colors(subgraph, layout=layout)
Now, we compute the shortest path through the graph between the two points from above.
%%time
from networkx.algorithms.shortest_paths.generic import shortest_path
path = shortest_path(graph, z_start_index, z_end_index, weight='weight')
length = 0
for source, sink in zip(path[:-1], path[1:]):
length += graph[source][sink]['weight']
print('Path length:', length)
from src.plot import plot_graph

# Construct a subgraph from the path
# so only the shortest-path nodes and edges get drawn.
path_graph = nx.Graph()
for node in path:
    path_graph.add_node(node, pos=graph_points[node])
for u, v in zip(path[:-1], path[1:]):
    path_graph.add_edge(u, v, weight=graph[u][v]['weight'])
_ = plot_graph(path_graph, layout=layout, edge_width=0,
               additional_data=graph_plot)
Since we only computed the Riemannian distance for each edge using a single midpoint, the graph length is not exactly correct. It is not as strongly biased as the discrete geodesic algorithm's length estimate, but we should measure it as well with the interpolate function for a fair comparison.
# Measure the graph solution with the same resampling procedure as the
# discrete solution, for a fair comparison.
graph_curve = graph_points[path]
evaluate_curve(graph_curve)

# Overlay both solutions: graph curve (green) vs. discrete curve (red).
graph_curve_plot = go.Scatter(
    x=graph_curve[:, 0],
    y=graph_curve[:, 1],
    mode='lines',
    line=dict(width=5, color='#3CA64D')
)
discrete_curve_plot = go.Scatter(
    x=discrete_curve[:, 0],
    y=discrete_curve[:, 1],
    mode='lines',
    line=dict(width=5, color='#d32f2f')
)
data = [heatmap, latent_scatter_plot, graph_curve_plot, task_plot,
        discrete_curve_plot]
iplot(go.Figure(data=data, layout=layout), config=config)